{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_text(file_path):\n",
    "    with open(file_path) as f:\n",
    "        text = f.read()\n",
    "\n",
    "    char2idx = {c: i + 3 for i, c in enumerate(set(text))}\n",
    "    char2idx['<pad>'] = 0\n",
    "    char2idx['<start>'] = 1\n",
    "    char2idx['<end>'] = 2\n",
    "\n",
    "    vector = np.array([char2idx[char] for char in list(text)])\n",
    "    return vector, char2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector, char2idx = parse_text('shakespeare.txt')\n",
    "idx2char = {i: c for c, i in char2idx.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "seq_len = 100\n",
    "hidden_dim = 128\n",
    "n_layers = 2\n",
    "beam_width = 5\n",
    "clip_norm = 100.0\n",
    "skip = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cell_fn():\n",
    "    return tf.nn.rnn_cell.ResidualWrapper(\n",
    "        tf.nn.rnn_cell.GRUCell(\n",
    "            hidden_dim, kernel_initializer = tf.orthogonal_initializer()\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "def multi_cell_fn():\n",
    "    return tf.nn.rnn_cell.MultiRNNCell([cell_fn() for _ in range(n_layers)])\n",
    "\n",
    "\n",
    "def clip_grads(loss):\n",
    "    variables = tf.trainable_variables()\n",
    "    grads = tf.gradients(loss, variables)\n",
    "    clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm)\n",
    "    return zip(clipped_grads, variables)\n",
    "\n",
    "\n",
    "class Model:\n",
    "    def __init__(self, seq_len, vocab_size, hidden_dim):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.batch_size = tf.shape(self.X)[0]\n",
    "        encoder_embeddings = tf.Variable(\n",
    "            tf.random_uniform([vocab_size, hidden_dim], -1, 1)\n",
    "        )\n",
    "        cells = multi_cell_fn()\n",
    "        helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "            inputs = tf.nn.embedding_lookup(encoder_embeddings, self.X),\n",
    "            sequence_length = tf.count_nonzero(self.X, 1, dtype = tf.int32),\n",
    "        )\n",
    "        dense_layer = tf.layers.Dense(vocab_size)\n",
    "        decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "            cell = cells,\n",
    "            helper = helper,\n",
    "            initial_state = cells.zero_state(self.batch_size, tf.float32),\n",
    "            output_layer = dense_layer,\n",
    "        )\n",
    "\n",
    "        decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "            decoder = decoder\n",
    "        )\n",
    "        self.logits = decoder_output.rnn_output\n",
    "\n",
    "        decoder = tf.contrib.seq2seq.BeamSearchDecoder(\n",
    "            cell = cells,\n",
    "            embedding = encoder_embeddings,\n",
    "            start_tokens = tf.tile(\n",
    "                tf.constant([char2idx['<start>']], dtype = tf.int32), [1]\n",
    "            ),\n",
    "            end_token = char2idx['<end>'],\n",
    "            initial_state = tf.contrib.seq2seq.tile_batch(\n",
    "                cells.zero_state(1, tf.float32), beam_width\n",
    "            ),\n",
    "            beam_width = beam_width,\n",
    "            output_layer = dense_layer,\n",
    "        )\n",
    "\n",
    "        decoder_out, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "            decoder = decoder, maximum_iterations = seq_len\n",
    "        )\n",
    "\n",
    "        self.predict = decoder_out.predicted_ids[:, :, 0]\n",
    "\n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.contrib.seq2seq.sequence_loss(\n",
    "                logits = self.logits,\n",
    "                targets = self.Y,\n",
    "                weights = tf.to_float(tf.ones_like(self.Y)),\n",
    "            )\n",
    "        )\n",
    "        self.global_step = tf.Variable(0, trainable = False)\n",
    "        self.optimizer = tf.train.AdamOptimizer().apply_gradients(\n",
    "            clip_grads(self.cost), global_step = self.global_step\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def start_sentence(x):\n",
    "    _x = np.full([x.shape[0], 1], char2idx['<start>'])\n",
    "    return np.concatenate([_x, x], 1)\n",
    "\n",
    "\n",
    "def end_sentence(x):\n",
    "    _x = np.full([x.shape[0], 1], char2idx['<end>'])\n",
    "    return np.concatenate([x, _x], 1)\n",
    "\n",
    "\n",
    "batches = []\n",
    "for i in range(0, len(vector) - seq_len, skip):\n",
    "    batches.append(vector[i : i + seq_len])\n",
    "X = np.array(batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(seq_len, len(char2idx), hidden_dim)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.39it/s, cost=1.97]\n",
      "train minibatch loop:   0%|          | 1/436 [00:00<00:44,  9.83it/s, cost=1.98]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "there of thou have of thou have of thou have of thou have of thou have thou have of thee\n",
      "And thou ha\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.58it/s, cost=1.79]\n",
      "train minibatch loop:   0%|          | 2/436 [00:00<00:40, 10.74it/s, cost=1.81]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "therefore of thee,\n",
      "And thou have thou have thou have thou have me,\n",
      "And though thou have me thee <end><end><end>\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.32it/s, cost=1.68]\n",
      "train minibatch loop:   0%|          | 2/436 [00:00<00:40, 10.67it/s, cost=1.71]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "t thou art thou hast of thee\n",
      "And thou art thou art thou art thou hast of the done.\n",
      "\n",
      "PETRUCHIO:\n",
      "<end><end><end><end>\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.14it/s, cost=1.62]\n",
      "train minibatch loop:   0%|          | 1/436 [00:00<00:43,  9.92it/s, cost=1.65]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "think thee, sir, and thou hast thou hast of the duke.\n",
      "\n",
      "PROSPERO:\n",
      "What is not thou art thou art <end><end><end><end>\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.46it/s, cost=1.56]\n",
      "train minibatch loop:   0%|          | 0/436 [00:00<?, ?it/s, cost=1.56]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "think thee, sir.\n",
      "\n",
      "PROSPERO:\n",
      "What, sir, thou hast thou hast thou hast thou hast\n",
      "And thou art thou art\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.49it/s, cost=1.53]\n",
      "train minibatch loop:   0%|          | 2/436 [00:00<00:40, 10.75it/s, cost=1.57]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "think your house.\n",
      "\n",
      "PROSPERO:\n",
      "What is not thou hast thou hast thou hast of thee.\n",
      "\n",
      "PROSPERO:\n",
      "Where is \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.47it/s, cost=1.49]\n",
      "train minibatch loop:   0%|          | 2/436 [00:00<00:40, 10.64it/s, cost=1.55]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "thou shall not thou hast thou hast\n",
      "That thou hast thou hast thou hast thou hast thou hast\n",
      "And think \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.52it/s, cost=1.47]\n",
      "train minibatch loop:   0%|          | 2/436 [00:00<00:41, 10.40it/s, cost=1.53]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " thou hast thou hast thou hast thou hast\n",
      "That thou hast thou hast thou hast thou hast thou art.\n",
      "\n",
      "PRO\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.52it/s, cost=1.45]\n",
      "train minibatch loop:   0%|          | 2/436 [00:00<00:39, 10.87it/s, cost=1.51]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " thou hast thou hast thou hast thou hast\n",
      "And thou art thou hast thou didst thou hast strange.\n",
      "\n",
      "PROSP\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 436/436 [00:41<00:00, 10.52it/s, cost=1.44]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " thou hast thou hast thou hast thou hast\n",
      "That thou hast thou art thou art thou art thou ar<end><end><end><end><end>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "for e in range(10):\n",
    "    lasttime = time.time()\n",
    "    train_loss, test_loss = 0, 0\n",
    "    pbar = tqdm(range(0, len(X), batch_size), desc = 'train minibatch loop')\n",
    "    for i in pbar:\n",
    "        batch_x = X[i : min(i + batch_size, len(X))]\n",
    "        batch_y = end_sentence(batch_x)\n",
    "        batch_x = start_sentence(batch_x)\n",
    "        loss, _ = sess.run(\n",
    "            [model.cost, model.optimizer],\n",
    "            feed_dict = {model.X: batch_x, model.Y: batch_y},\n",
    "        )\n",
    "        assert not np.isnan(loss)\n",
    "        train_loss += loss\n",
    "        pbar.set_postfix(cost = loss)\n",
    "\n",
    "    batch_x = start_sentence(X[:batch_size])\n",
    "    ints = sess.run(model.predict, feed_dict = {model.X: batch_x})[0]\n",
    "    print('\\n' + ''.join([idx2char[i] for i in ints]) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
